import os
import sys
import math
import argparse
from datetime import datetime
from torchvision import models as torchvision_models

from data.dl_getter import DATASETS, n_cls, sh, input_range
from tool.util import set_seed, bool_flag


# Sorted names of every lowercase, non-dunder, callable attribute of
# torchvision.models (presumably the architecture constructors -- the exact
# set depends on the installed torchvision version).
torchvision_archs = sorted(
    attr_name
    for attr_name, attr in torchvision_models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr)
)


def _exp_name_from_argv(argv):
    """Condense raw command-line tokens into a ``flag:value,flag:value`` string.

    Tokens containing ``--`` are treated as flags (dashes removed, comma
    separated from the previous entry); every other token is appended to the
    preceding entry after a ``:``.
    """
    parts = []
    for i, token in enumerate(argv):
        if '--' in token:
            parts.append((',' if i > 0 else '') + token.replace('--', ''))
        else:
            parts.append(':' + token)
    return ''.join(parts)


def _apply_contrastive_defaults(args):
    """Overwrite ``args`` in place with the contrastive-learning training recipe.

    These values replace whatever was given on the command line. The schedule
    values are written to ``args`` (not to locals) so that the warm-up target
    below and any later consumer of ``args`` see them.
    """
    args.head = 'con'
    args.lr = 0.5
    args.epochs = 1000
    args.bsz = 1024
    args.cosine = True
    args.lr_decay_epochs = [700, 800, 900]
    args.lr_decay_rate = 0.1
    args.weight_decay = 1e-4

    # from supcon
    # warm-up for large-batch training
    if args.bsz > 256:
        args.warm = True
    if args.warm:
        args.warmup_from = 0.01
        args.warm_epochs = 10
        if args.cosine:
            # Value of the cosine schedule at the end of the warm-up phase.
            eta_min = args.lr * (args.lr_decay_rate ** 3)
            args.warmup_to = eta_min + (args.lr - eta_min) * (
                    1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
        else:
            args.warmup_to = args.lr


def get_general_args(parser=None, is_nb: bool = False) -> argparse.Namespace:
    """Build the experiment CLI, parse it, and attach derived settings.

    Args:
        parser: Optional ``argparse.ArgumentParser`` to extend. When ``None`` a
            new parser named ``'Evaluation'`` is created.
        is_nb: When true, parse an empty argument list instead of ``sys.argv``
            so that only defaults are used (e.g. when called from a notebook).

    Returns:
        The parsed namespace, extended with derived attributes: ``upload_dir``,
        ``in_ch``, ``sh``, ``num_labels``/``n_cls``, ``input_range``,
        ``reloaded``, the method flags ``crl``/``ebm``/``uc_dl``/``sam``/``ech``,
        ``start_time``, ``load_path`` and ``gpu_id``. ``exp`` may be rewritten
        into an auto-generated experiment name.

    Side effects:
        Calls ``set_seed(args.random_seed)`` before returning. ``parse_args``
        exits the process on invalid command-line input.
    """
    if parser is None:
        parser = argparse.ArgumentParser('Evaluation')

    # ---- experiment bookkeeping ----
    parser.add_argument('--proj_name', default='econ', type=str, help='experiment name')
    parser.add_argument('--exp', default='', type=str, help='experiment name')
    parser.add_argument('--tags', type=str, default='', metavar='N',
                        help='')
    parser.add_argument('--output_dir', default=os.path.expanduser("~/exp.log/"),
                        help='Path to save logs and chkpts')
    parser.add_argument('--exp_load', default='', type=str, help='experiment name')
    parser.add_argument('--wandb_dir', default=os.path.expanduser("~/wandb.log/"),
                        help='Path to save logs and chkpts')
    parser.add_argument('--wandb_entity', default="ml_research",
                        help='wandb entity')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='debug mode or not')

    # ---- method / optimisation ----
    parser.add_argument('--method', default='ce', type=str, choices=['ce',
                        'ce_ec', 'ce_ech', 'econ', 'supcon', 'simclr',
                        'finetune', 'ebm', 'jem', 'sadajem', 'sadajem_ec',
                        'evaluate',], help="method name")
    parser.add_argument('--mixup_criterion', default=True, type=bool_flag, help="loss mixup")
    parser.add_argument('--resume', action='store_true', default=False, help="resume")
    parser.add_argument('--arch', default='resnet34', type=str, help='Architecture')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument("--chkpt_key", default="teacher", type=str, help='Key to use in the chkpt (example: "teacher")')
    parser.add_argument('--epochs', default=200, type=int, help='Number of epochs of training.')
    parser.add_argument("--optimizer", choices=["adam", "sgd"], default="sgd")
    parser.add_argument("--lr", default=0.1, type=float, help="""Learning rate at the beginning of
        training (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.
        We recommend tweaking the LR depending on the chkpt evaluated.""") # sadajem .1
    # Alternative schedules noted by the authors: [160, 180] for jem,
    # [60, 120, 180] for sadajem.
    parser.add_argument("--lr_decay_epochs", nargs="+", type=int, default=[60, 120, 160],
                        help="decay learning rate by decay_rate at these epochs")
    parser.add_argument("--lr_decay_rate", type=float, default=.2,  # sadajem .2
                        help="learning rate decay multiplier")
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    parser.add_argument('--bsz', default=128, type=int, help='batch-size')  # sadajem ebm: 64, cls: 128
    parser.add_argument('--bsz_vl', default=256, type=int, help='batch-size')
    parser.add_argument("--warmup_iters", type=int, default=1000,
                        help="number of iters to linearly increase learning rate, if -1 then no warmmup")
    parser.add_argument('--reload', default=True, type=bool_flag,
                        help='restart when encounter a large loss')

    # sadajem
    parser.add_argument("--l2_coeff", type=float, default=0)

    # supcon
    parser.add_argument('--cosine', action='store_true', default=False, help='using cosine annealing')
    parser.add_argument('--warm', action='store_true', help='warm-up for large batch training')
    parser.add_argument('--temp', type=float, default=0.10,
                        help='temperature for loss function')
    parser.add_argument('--cropsize', type=int, default=32, help='parameter for RandomResizedCrop')

    parser.add_argument('--x_noise', action='store_true', default=False,
                        help="noise")
    parser.add_argument('--bs', action='store_true', default=False,
                        help="balanced sampling")

    parser.add_argument("--ball_sz", type=float, default=1e-1)
    parser.add_argument("--ball_dec", action='store_true', default=False, help="")
    parser.add_argument("--l_norm", choices=["2", "infty"], default="2")

    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
    parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
    parser.add_argument('--dataset', type=str, default='cifar10', choices=DATASETS,
                        help='The dataset to use for training)')

    # ---- evaluation of robustness ----
    parser.add_argument('--attack', default='pgd', type=str,
                        help='Attack type (CW requires FoolBox)',
                        choices=('pgd', 'cw', 'ifgsm_7', 'ifgsm_10', 'ifgsm_20',
                                 'pgd_7', 'pgd_10', 'pgd_20', 'pgd_40', "l2_basic", 'aa'))
    parser.add_argument('--epsilon', default=0.031, type=float,
                        help='Attack perturbation magnitude')
    parser.add_argument('--num_steps', default=40, type=int,
                        help='Number of PGD steps')
    parser.add_argument('--step_size', default=0.007, type=float,
                        help='PGD step size')
    parser.add_argument('--num_restarts', default=1, type=int,  # 5: default
                        help='Number of restarts for PGD attack')
    parser.add_argument('--no_random_start', dest='random_start',
                        action='store_false',
                        help='Disable random PGD initialization')
    parser.add_argument('--binary_search_steps', default=5, type=int,
                        help='Number of binary search steps for CW attack')
    parser.add_argument('--max_iterations', default=1000, type=int,
                        help='Max number of Adam iterations in each CW'
                             ' optimization')
    parser.add_argument('--initial_const', default=1E-2, type=float,
                        help='Initial constant for CW attack')
    parser.add_argument('--tau_decrease_factor', default=0.9, type=float,
                        help='Tau decrease factor for CW attack')
    parser.add_argument('--random_seed', default=-1, type=int,
                        help='Random seed for permutation of test instances')
    parser.add_argument('--num_eval_batches', default=None, type=int,
                        help='Number of batches to run evalaution on')
    parser.add_argument('--save_freq', default=50, type=int,
                        help='Save chkpt every x epochs.')
    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--n_eg', default=5, type=int,
                        help='the num of eg. for analysis')

    parser.add_argument('--alpha', default=1.0, type=float,
                        help='hyper param for beta distribution in mixup')
    parser.add_argument('--beta', default=2.0, type=float,
                        help='hyper param for beta distribution in mixup')
    parser.add_argument('--extrapolate', action='store_true', default=False,
                        help='debug mode or not')
    parser.add_argument('--ablation', action='store_true', default=False,
                        help='stard+mef')

    parser.add_argument('--eval_robust', default=False, type=bool_flag,
                        help="adv robustness")

    parser.add_argument('--use_ma', default=False, action="store_true",
                        help="")
    parser.add_argument('--coef_ma', default=0.1, type=float,
                        help='ma coef')

    # "all" is the default, so it must also be an accepted explicit choice.
    parser.add_argument('--analysis_of', type=str, default="all",
                        choices=["all", "calibration", "logits", "robust", "ece", "ece_HM", "ece_HM_gap",
                                 "ece_gap", "gen", "cgen", "logp_hist", "OOD", "test_clf"],
                        help='')
    parser.add_argument('--overwrite', action="store_true",
                        default=False, help='')
    # NOTE(review): no `type` is given, so any value supplied on the command
    # line is stored as a (truthy) string, including "False" and "0".
    parser.add_argument('--cw', default=False)

    # ---- EBM specific ----
    parser.add_argument("--uncond", action="store_true")
    parser.add_argument("--buffer_size", type=int, default=10000)
    parser.add_argument("--reinit_freq", type=float, default=.05)

    parser.add_argument("--sgld_n_steps", type=int, default=5)
    parser.add_argument("--sgld_lr", type=float, default=1)
    parser.add_argument("--sgld_std", type=float, default=0)
    parser.add_argument("--sgld_clip", default=True, type=bool_flag,
                        help="")
    parser.add_argument('--sgld_pre_noise', default=False, type=bool_flag,
                        help="noise annealing")

    # ---- logging + evaluation ----
    parser.add_argument('--img_dir', default=os.path.expanduser("~/exp.img/"),
                        help='Path to save generated imgs')
    parser.add_argument("--print_every", type=int, default=100)
    parser.add_argument("--n_sample_steps", type=int, default=100000)

    # Toy
    parser.add_argument("--toy_n_samples", type=int, default=50000)

    # gen
    parser.add_argument('--eval_gen', default=False, type=bool_flag,
                        help="")

    # jem + @
    parser.add_argument('--gen_aug', action='store_true', default=False,
                        help="augmentation while generating img by sgld steps")
    parser.add_argument('--anneal', action='store_true', default=False,
                        help="noise annealing")
    parser.add_argument('--in_norm', action='store_true', default=False,
                        help="")
    parser.add_argument('--prof', action='store_true', default=False,
                        help="flag for profile")

    # inex
    parser.add_argument("--inex_proj", choices=["", "1", "2", "inf", "norm"], default="")
    parser.add_argument("--inex_pre_proj", action='store_true', default=False, help="")
    parser.add_argument("--inex_post_proj", action='store_true', default=False, help="")
    parser.add_argument("--inex_noise", action='store_true', default=False, help="")
    parser.add_argument("--inex_no_all", action='store_true', default=False, help="")
    parser.add_argument("--inex_nphase1", type=int, default=2, help="")
    parser.add_argument("--inex_nbreak", type=int, default=2, help="")
    parser.add_argument("--inex_resz", type=int, default=0, help="")
    parser.add_argument("--inex_linsp", action='store_true', default=False, help="")

    parser.add_argument("--i_sample_f", type=int, default=0,
                        help="0: sample from buffer, 1: sample noise, 2: sample data")
    parser.add_argument('--no_tc', action='store_true', default=False,
                        help='no teacher network')

    # ---- ce energy projection head ----
    parser.add_argument("--head", choices=["lin", "cos", "rc", "rcn", "ep"], default="ep")
    parser.add_argument("--head_relu", action='store_true', default=False, help="")
    parser.add_argument("--no_head_norm", action='store_true', default=False, help="")
    parser.add_argument("--no_head_sq", action='store_true', default=False, help="")
    parser.add_argument("--head_lin", action='store_true', default=False, help="")
    parser.add_argument("--ms_share", action='store_true', default=False, help="")
    parser.add_argument("--n_headd_l", type=int, default=1, help="")
    parser.add_argument("--orth_reg", action='store_true', default=False, help="")
    parser.add_argument("--i_act_u", type=int, default=0)
    parser.add_argument("--i_ts_act_u", type=int, default=-1, help="")
    parser.add_argument("--feat_dim", type=int, default=64, help="")

    parser.add_argument("--head_post_amp", action='store_true', default=False, help="num non lin")
    parser.add_argument("--head_skip_r", action='store_true', default=False, help="num non lin")
    parser.add_argument("--amp_a1", type=float, default=1.)
    parser.add_argument("--amp_a2", type=float, default=1.)
    # NOTE(review): store_true with default=True -> always True, cannot be disabled from the CLI.
    parser.add_argument("--head_ms_norm", action='store_true', default=True, help="num non lin")
    parser.add_argument("--headn_nl", action='store_true', default=False, help="num non lin")
    parser.add_argument("--headn_g", action='store_true', default=False, help="group")
    parser.add_argument("--n_headn_l", type=int, default=1, help="")
    parser.add_argument("--headn_b", action='store_true', default=False, help="bias")
    parser.add_argument("--head_post_n", action='store_true', default=False, help="pre_normalize")

    parser.add_argument("--head_dim_rd", action='store_true', default=False, help="dim reduction")

    # NOTE(review): store_true with default=True -> always True, cannot be disabled from the CLI.
    parser.add_argument("--headd_b", action='store_true', default=True, help="bias")
    parser.add_argument("--head_eval_clip", action='store_true', default=False, help="num non lin")

    parser.add_argument("--cos_sq", action='store_true', default=False, help="bias")
    parser.add_argument("--head_alpha", type=float, default=1.)

    parser.add_argument("--kernel_norm", action='store_true', default=False, help="bias")
    parser.add_argument("--kernel_diff", action='store_true', default=False, help="bias")

    parser.add_argument("--diff_hparam", type=float, default=1.)
    parser.add_argument("--f_max", action='store_true', default=False, help="bias")

    parser.add_argument("--attn_hier", action='store_true', default=False, help="bias")
    parser.add_argument("--attn_dim", type=int, default=0)

    # head combination
    parser.add_argument("--d0", action='store_true', default=False, help="bias")
    parser.add_argument("--d01", action='store_true', default=False, help="bias")
    parser.add_argument("--cl_max", action='store_true', default=False, help="bias")
    parser.add_argument("--cl_mean", action='store_true', default=False, help="bias")
    parser.add_argument("--cl_exp", action='store_true', default=False, help="bias")

    # NOTE(review): store_true with default=True -> always True, cannot be disabled from the CLI.
    parser.add_argument("--mcog", action='store_true', default=True, help="bias")
    parser.add_argument("--mcog_hier", action='store_true', default=False, help="bias")

    # ec
    parser.add_argument("--ec", action='store_true', default=False, help="num non lin")
    parser.add_argument("--i_pos_fn", type=int, default=0, help="0~2")
    parser.add_argument("--i_neg_fn", type=int, default=0, help="0,1")

    # sam
    parser.add_argument("--rho", type=float, default=2)

    # info init
    parser.add_argument("--init", type=str, default='gm', help='gm: Gaussian Mixture,  u: uniform')

    # fidis
    parser.add_argument("--every_fidis", type=int, default=10,
                        help="")
    parser.add_argument("--auc_epoch", type=int, default=200)

    # landscape
    parser.add_argument("--landscape", type=str, default="images", choices=['images', 'parameters'])
    parser.add_argument("--landscape_step", type=int, default=40)
    parser.add_argument("--landscape_range", type=float, default=0.1)
    # bool_flag instead of type=bool: bool("False") is True, so the plain
    # builtin would turn every explicitly supplied value into True.
    parser.add_argument("--landscape_no_record", type=bool_flag, default=False)

    # probability tables
    # Stored as args.no_get_table: True unless the flag is passed.
    parser.add_argument("--no_get_table", action="store_false", default=True, help="id,ood probability table")
    parser.add_argument("--amp_dim", type=int, default=1)
    parser.add_argument("--amp_select", type=str, default='top1')

    parser.add_argument("--eval", action='store_true', default=False, help="bias")

    # analysis
    parser.add_argument("--fwd_result", action="store_true", default=False, help="id,ood probability table")

    # Notebook mode ignores sys.argv and uses the defaults only.
    args = parser.parse_args(args=[]) if is_nb else parser.parse_args()

    # ---- dataset-derived settings ----
    # upload_dir keeps the experiment name as given, before it is rewritten below.
    args.upload_dir = args.exp
    rgb_datasets = ['cifar10', 'svhn', 'stl10', 'cifar10H', 'cifar100', 'imagenet']
    args.in_ch = 3 if args.dataset in rgb_datasets else 1  # otherwise single channel (e.g. mnist)
    args.sh = sh[args.dataset]
    args.num_labels = args.n_cls = n_cls[args.dataset]
    args.input_range = input_range[args.dataset]

    args.reloaded = False

    # ---- method-derived flags ----
    # --mcog_hier implies --mcog and forces the hierarchical CE method.
    if args.mcog_hier:
        args.mcog = True
        args.method = 'ce_ech'
    mth = args.method
    args.crl = mth in ['supcon', 'econ', 'simclr']
    args.ebm = mth in ['ebm', 'jem', 'sadajem']
    args.uc_dl = mth in ['jem', 'sadajem']
    args.sam = 'sa' in mth
    args.ech = 'ech' in mth
    if args.crl:
        _apply_contrastive_defaults(args)

    # ---- experiment name ----
    # Auto-naming is skipped for test/tmp experiments and when resuming.
    test_cond = 'test' not in args.exp and 'tmp' not in args.exp and not args.resume
    if args.proj_name == 'inex':
        spstr = '_linsp' if args.inex_linsp else ''
        args.exp = args.dataset + '_'
        args.exp += str(int((args.bsz // 9) * 5)) + 'pair_'
        args.exp += args.inex_proj + spstr + '_' + str(args.inex_resz)
    elif test_cond and not args.evaluate:
        # Name the run after the raw command line plus a timestamp.
        exp_name = (_exp_name_from_argv(sys.argv[1:]) + '_'
                    + datetime.today().strftime("%y%m%d_%H%M_%S"))
        if args.exp:
            args.exp = args.exp + '_' + exp_name
        else:
            args.exp = exp_name
    if args.debug:
        args.epochs = 20

    args.start_time = datetime.today().strftime("%y%m%d_%H%M")
    args.load_path = None
    # NOTE(review): GPU index is hard-coded and not configurable from the CLI.
    args.gpu_id = [3]
    set_seed(args.random_seed)
    return args


